import tiktoken
from functools import wraps
from filelock import FileLock

def token_count_decorator(func=None, model="gpt-4o-mini", flow="input", batch=True, clear=False):
    '''
    Count tokens passing through a function and accumulate the token totals
    plus an estimated fee in ``token_count.txt``.

    Usable bare (``@token_count_decorator``) or with arguments
    (``@token_count_decorator(model="gpt-4o", flow="together")``).

    Args:
        func: The decorated function when used bare; ``None`` when the
            decorator is called with arguments.
        model: Default model name used for tokenization and pricing. A call
            can override it by passing ``gpt_model=...`` to the decorated
            function (that keyword is passed through to the function but is
            not itself counted as input text).
        flow: Which tokens to count: ``"input"`` (string positional and
            keyword arguments), ``"output"`` (the return value, if it is a
            string) or ``"together"`` (both). Any other value counts nothing.
        batch: If True, the computed fee is halved (presumably to model a
            batch-API discount).
        clear: If True, the totals already stored in the record file are
            discarded on EVERY call, so the file only reflects the latest
            call. If False, each call's counts are added to the stored totals.

    Returns:
        The wrapped function (bare use) or a decorator (use with arguments).
    '''
    # (input_rate, output_rate) per 1M tokens -- the fee is divided by 1e6
    # below. Embedding models are charged for input tokens only.
    # NOTE(review): these rates are hard-coded snapshots and may be out of
    # date; confirm against the provider's current price list.
    prices = {
        "gpt-4o-mini": (0.15, 0.6),
        "gpt-4o": (5, 15),
        "text-embedding-3-large": (0.13, 0),
        "text-embedding-3-small": (0.02, 0),
        "text-embedding-ada-002": (0.1, 0),
    }
    default_price = (1.5, 2)  # applied to any model not listed above
    # Tokenizer used when tiktoken cannot map the model name to an encoding.
    fallback_encoding = "cl100k_base"
    # Resolved relative to the current working directory at call time.
    token_count_path = "token_count.txt"

    def cal_cost(input_token_num, output_token_num, real_model):
        """Return the estimated fee for the given token counts and model."""
        input_rate, output_rate = prices.get(real_model, default_price)
        fee = (input_rate * input_token_num + output_rate * output_token_num) / 1000000.0
        if batch:
            fee = fee * 0.5
        return fee

    def read_totals():
        """Return the stored (input, output, fee) totals, or zeros if absent/unreadable."""
        try:
            with open(token_count_path, 'r') as file:
                # Built as one tuple so a parse failure leaves nothing half-read.
                return (
                    int(file.readline()),
                    int(file.readline()),
                    float(file.readline()),
                )
        except FileNotFoundError:
            return 0, 0, 0.0
        except ValueError:
            # An empty or partially written file would otherwise make every
            # later call fail *after* the wrapped (possibly paid) call ran.
            import warnings
            warnings.warn(
                f"{token_count_path} is unreadable or corrupt; restarting totals from zero.",
                RuntimeWarning,
                stacklevel=2,
            )
            return 0, 0, 0.0

    def record(input_token_num, output_token_num, fee):
        """Add this call's counts to the record file under a cross-process file lock."""
        with FileLock(token_count_path + ".lock"):
            if clear:
                input_total, output_total, fee_total = 0, 0, 0.0
            else:
                input_total, output_total, fee_total = read_totals()
            with open(token_count_path, 'w') as file:
                file.write(
                    f"{input_total + input_token_num}\n"
                    f"{output_total + output_token_num}\n"
                    f"{fee_total + fee}\n"
                )

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            real_model = kwargs.get("gpt_model", model)
            try:
                enc = tiktoken.encoding_for_model(real_model)
            except KeyError:
                # tiktoken does not know this model name; approximate with a
                # common encoding instead of failing (pricing already has a
                # default for unknown models).
                enc = tiktoken.get_encoding(fallback_encoding)

            def count_tokens(value):
                # Only plain strings are tokenized; other values count as 0.
                # disallowed_special=() counts text such as "<|endoftext|>"
                # as ordinary text instead of raising ValueError.
                if not isinstance(value, str):
                    return 0
                return len(enc.encode(value, disallowed_special=()))

            input_tokens, output_tokens = 0, 0
            if flow in ("input", "together"):
                input_tokens += sum(count_tokens(arg) for arg in args)
                # "gpt_model" only selects the model; it is not prompt text.
                input_tokens += sum(
                    count_tokens(value)
                    for key, value in kwargs.items()
                    if key != "gpt_model"
                )

            result = func(*args, **kwargs)

            if flow in ("output", "together"):
                output_tokens += count_tokens(result)

            fee = cal_cost(input_tokens, output_tokens, real_model)
            record(input_tokens, output_tokens, fee)
            return result

        return wrapper

    if func is None:  # called with arguments: @token_count_decorator(...)
        return decorator
    return decorator(func)  # used bare: @token_count_decorator